#!/usr/bin/env python
## This file is part of biopy.
## Copyright (C) 2010 Joseph Heled
## Author: Joseph Heled <jheled@gmail.com>
## See the files gpl.txt and lgpl.txt for copying conditions.

from __future__ import division

import optparse, sys, os.path

from math import log
from numpy import mean, median
from scipy.optimize import fmin_powell

from biopy.genericutils import fileFromName

from biopy import INexus, beastLogHelper, demographic
from biopy.treeutils import toNewick, countNexusTrees, getTreeClades, \
     addAttributes, attributesVarName
from biopy.bayesianStats import hpd

# Command line definition. Options are kept as strings by optparse and
# converted (and range checked) below, so that a bad value is reported as a
# usage error instead of surfacing later as a traceback.
parser = optparse.OptionParser(usage =
                               """ %prog [OPTIONS] tree posterior-trees.nexus

	Annotate tree with posterior estimate of population sizes. Prints NEWICK
	tree to standard output. On UNIX it is easy to get the tree for the
	first argument from biopy itself, e.g (assuming the bash shell)
        starbeast_posterior_popsizes $(summary_tree trees.nexus) trees.nexus.""")

parser.add_option("-b", "--burnin", dest="burnin",
                  help="Burn-in amount (percent, default %default)", default = "10")

parser.add_option("-e", "--every", dest="every", metavar="E",
                  help="""thin out - take one tree for every E. Especially
                  useful if you run out of memory (default all,
                  i.e. %default)""", default = "1")  

## parser.add_option("-o", "--optimizeTaxa", dest="optTaxa",
##                   help="Optimize population size at time 0 - default is to fix
##                   it at mean posterior value.",
##                   action="store_true", default = False)

parser.add_option("-p", "--progress", dest="progress",
                  help="Print out progress messages to terminal (standard error)",
                  action="store_true", default = False)

options, args = parser.parse_args()

progress = options.progress

# 'every' is used as a slice step and as a divisor, so it must be a positive
# integer. parser.error() prints the usage plus the message and exits.
try :
  every = int(options.every)
except ValueError :
  parser.error("--every expects an integer, got %r" % options.every)
if every < 1 :
  parser.error("--every must be at least 1, got %d" % every)

# Burn-in is given in percent and stored as a fraction. A value of 100 or
# more would leave no trees to process. The comparison is also false for NaN.
try :
  burnIn = float(options.burnin)/100.0
except ValueError :
  parser.error("--burnin expects a number, got %r" % options.burnin)
if not (0 <= burnIn < 1) :
  parser.error("--burnin must be a percentage in [0,100), got %s" % options.burnin)

if len(args) != 2 :
  parser.print_help(sys.stderr)
  sys.exit(1)
  
treeText, nexusTreesFileName = args[:2]

target = INexus.Tree(treeText)

targetClades = getTreeClades(target, True)

# A mapping with target clades as keys
cladesDict = dict()
cladeNode = dict()

for c,n in targetClades:
  cladesDict[frozenset(c)] = []
  cladeNode[frozenset(c)] = n

noTaxaOpt = True ; # not options.optTaxa
if noTaxaOpt :
  taxaDict = dict([(tx,[]) for tx in target.get_taxa()])

try :
  nexFile = fileFromName(nexusTreesFileName)
except Exception,e:
  # report error
  print >> sys.stderr, "Error:", e.message
  sys.exit(1)

if progress:
  print >> sys.stderr, "counting trees ...,",
  
nTrees = countNexusTrees(nexusTreesFileName)

nexusReader = INexus.INexus()

if progress:
  print >> sys.stderr, "reading %d trees ...," % int((nTrees * (1-burnIn) / every)),

# Root 'branches'
rBranches = []
# Stays True only while every sampled root has equal population at both ends
constRoot = True

# Sentinel: stays None when the reader yields nothing, which would otherwise
# leave 'tree' unbound for the constDemos test below.
tree = None

# NOTE(review): the slice stops at -1; whether that drops the final tree
# depends on how INexus.read interprets the slice - confirm.
for tree in nexusReader.read(nexFile, slice(int(burnIn*nTrees), -1, every)):
  has = beastLogHelper.setDemographics([tree])
  if not has[0] :
    print >> sys.stderr, "No population size data in tree(s), abort"
    sys.exit(1)

  for c,node in getTreeClades(tree, True):
    cladeSet = frozenset(c)
    samples = cladesDict.get(cladeSet)
    if samples is None :
      # clade is not in the target tree
      continue

    # A tree clade present in target
    d = node.data.demographic
    l = d.naturalLimit()
    p0 = d.population(0)
    if l is None :
      # no natural limit: store the inverse population as a plain number
      ipop = 1/p0
    else :
      p1 = d.population(l)
      if node.id == tree.root :
        rBranches.append(l)
        constRoot = constRoot and (p0 == p1)
      # (start, end, total 1/N over branch, branch length)
      ipop = (p0,p1,d.integrate(l),l)

    samples.append(ipop)
    if noTaxaOpt and len(cladeSet) == 1 :
      # single element set: unpack its only member
      (tx,) = cladeSet
      taxaDict[tx].append(p0)

if tree is None :
  print >> sys.stderr, "No trees read (check burn-in and thinning), abort"
  sys.exit(1)

# Decided from the last tree read only: True when none of its nodes has a
# natural limit.
constDemos = all([tree.node(x).data.demographic.naturalLimit() is None
                  for x in tree.all_ids()])

if constDemos :
  # Constant demographics: annotate every target node that has samples with
  # the inverse of the mean sample ('dmv') and the inverted, reversed 95% HPD
  # bounds of the samples ('dmv95'). Nodes without samples get no attributes.
  for nodeId in target.all_ids() :
    tnode = target.node(nodeId)

    # samples of the first clade entry attached to this node, if any
    for c, cnode in targetClades :
      if tnode == cnode :
        invPops = cladesDict[frozenset(c)]
        break
    else :
      invPops = []

    if not invPops :
      continue

    popEstimate = 1/mean(invPops)
    interval = hpd(invPops, 0.95)
    # inverting swaps the order of the bounds, hence the reversal
    bounds = tuple([1/b for b in reversed(interval)])
    addAttributes(tnode, {'dmv' : '%g' % popEstimate,
                          'dmv95' : ('{%g,%g}' % bounds)})
else :
  if progress:
    print >>  sys.stderr, "optimizing ...," ,

  tids = target.all_ids()
  tax = target.get_terminals()

  # Positions in the parameter vector handed to the optimizer: 'me' gives one
  # slot per node id, 'ms' one additional slot per terminal, placed after all
  # the 'me' slots.
  me = dict(zip(tids, range(len(tids))))
  ms = dict(zip(tax, range(len(tids), len(tids)+len(tax))))

  # One record per target node:
  #   (start slots, end slot, branch length, sample statistics, node, fixed start)
  popIndices = []
  for ni in target.all_ids() :
    n = target.node(ni)
    pp = None
    if n.data.taxon :
      # tip: the start value has a slot of its own
      s = [ms[ni],None]
      if noTaxaOpt :
        # start value is fixed at the median of the sampled values
        pp = median(taxaDict[n.data.taxon])
    else :
      # internal node: start value is the sum of the children's end slots
      s = [me[x] for x in n.succ]
      
    # the root has no branch of its own, use the mean of the sampled ones
    branch = n.data.branchlength if ni != target.root else mean(rBranches)

    # samples of the first clade entry attached to this node, if any
    vals = []
    for c,n1 in targetClades:
      if n == n1 :
        vals = cladesDict[frozenset(c)]      
        break

    if len(vals) > 0 :
      pb,pe = [],[]
      for p0,p1,ipop,l in vals:
        # we want p_b,p_e which integrate (on average) over br to ipop and
        # have the same ratio, p_b/p_e == d.p(0)/d.p(l)
        #
        # if N(x) = p_0 *(1-x/h) + p_1 * (x/h), then
        # integral( 1/N(x), x=[0,h] ) = h lg(p_0/p_1)/(p_0-p_1)
        # so let alpha = p_0/p_1 and solve
        # p'_1 = h' log(alpha)/((alpha-1)*V), p''_0 = alpha * p'_1
        # would integrate to V

        alpha = p0/p1
        if alpha == 1.0 or branch == 0 :
          # Nothing to rescale for a constant sample, and nothing can be
          # rescaled onto a zero length target branch (ipop/branch would
          # divide by zero): use the sampled end points unchanged.
          p_e, p_b = p1, p0
        else :
          p_e = (log(alpha)) / ((alpha-1) * (ipop/branch))
          p_b = alpha * p_e
        # statistics are collected on the inverse population sizes
        pb.append(1/p_b)
        pe.append(1/p_e)

      # (count, sum and sum of squares at branch start, 95% HPD of the start
      #  population sizes, sum and sum of squares at branch end)
      v = (len(vals), sum(pb), sum([x**2 for x in pb]),
           hpd([1/x for x in pb], 0.95),
           sum(pe), sum([x**2 for x in pe]))
    else :
      v = (0,)*6
    # e is None for root node in constant root model
    e = me[ni] if not (ni == target.root and constRoot) else None
    popIndices.append((s, e, branch, v, n, pp))

  def err(p, pasn) :
    """Objective function for the optimizer.

    p    -- parameter vector; only the absolute values are used.
    pasn -- records of the form
            ((s1,s2), e, branch, (n,mnb,smnb,cib,mne,smne), nd, pp).

    Each record contributes branch * (n*x**2 - 2*x*S + Q) per branch end,
    where x is the inverse of the population size implied by p. When S and Q
    are the sum and the sum of squares of n samples this is the sum of squared
    deviations of x from those samples. Records with e None contribute the
    start term only; other records with a non-positive branch contribute
    nothing. A non-None pp replaces the start value derived from p.
    """
    sizes = [abs(v) for v in p]
    total = 0.0
    for (s1,s2),e,branch,(n,mnb,smnb,cib,mne,smne),nd,pp in pasn :
      if e is None :
        # single value for the whole branch: only the start statistics count
        inv = 1/(sizes[s1] + sizes[s2])
        term = n*inv**2 - 2*inv*mnb + smnb
        total += term * branch
        continue

      start = sizes[s1] + (sizes[s2] if s2 is not None else 0)
      end = sizes[e]
      if pp is not None :
        start = pp
      # 0 branch length (probably from optimizer)
      if not branch > 0 :
        continue

      invStart,invEnd = 1/start,1/end
      term = (n * invStart**2 - 2*invStart*mnb + smnb) \
             +  (n * invEnd**2 - 2*invEnd*mne + smne)
      total += term * branch
    return total

  nTaxa = len(tax)
  # One parameter per node plus one more per tip; for a rooted binary tree
  # with n tips that is (2n-1) + n.
  nPopParams = 3*nTaxa - 1 # 3n-1
  pops = [1]*nPopParams

  # Minimise err over the parameters; with full_output the optimum is the
  # first element of the result. err only uses absolute values, do the same.
  oo = fmin_powell(lambda x : err(x, popIndices), pops, disp=0, full_output=1)
  pops = [abs(x) for x in oo[0]]

  for (s1,s2),e,branch,(n,mnb,smnb,cib,mne,smne),nd,pp in popIndices :
    if e is None :
      # const root
      ps = pe = pops[s1] + pops[s2]
    else :
      ps,pe = pops[s1] + (pops[s2] if s2 is not None else 0), pops[e]
      if pp is not None :
        ps = pp
      # Zero branch: err ignores such branches, so the optimizer says nothing
      # about them. Use one value for both ends, the inverse of the mean
      # inverse population pooled over the start and end samples. (This used
      # to read n/mn with 'mn' undefined, failing with a NameError.) Left
      # untouched when there are no samples to average.
      if branch == 0 and n > 0 and (mnb + mne) > 0 :
        ps = pe = 2*n/(mnb + mne)

    attrs = {'dmv' : ('{%g,%g}' % (ps,pe)), 'dmt' : str(branch)}
    if n > 0 :
      # Without samples cib is only a numeric placeholder and cannot fill the
      # two-value format, so the interval is omitted for such nodes.
      attrs['dmv95'] = '{%g,%g}' % tuple(cib)
    addAttributes(nd, attrs)

if progress:
  print >>  sys.stderr, "done." 

print toNewick(target, attributes = attributesVarName)  

